#!/usr/bin/env python3
from pathlib import Path
import json


# Hyperparameters shared by every entry of the sweep.
#
# `create_sweep` unpacks this dict into each generated config *after* the
# per-entry keys, so a key added here that collides with a per-entry key
# (e.g. "seed" or "algorithm") would silently override the swept value.
defaults = {
    "lmd": 0.9,  # NOTE(review): presumably the lambda of GAE / lambda-returns -- confirm
    "reset_expert_vfn": False,
    "use_ppo_loss": True,
    "std_from_means": True,
    "use_expert_obsnormalizer": True,
    "state_pred_num_epochs": 8,
    "deterministic_experts": False,
    "learner_buffer_size": 2048,
    "expert_buffer_size": 19200,
    "max_episode_len": 1000,
    "num_epochs": 2,
    "num_rollouts": 8,
    "expert_tgtval": "gae",
    "gamma": 0.995,  # NOTE(review): presumably the reward discount factor -- confirm
    # NOTE(review): the "pret_" prefix presumably means pre-training -- confirm
    # against the code that consumes these configs.
    "pret_num_epochs": 32,
    "pret_num_rollouts": 8,
    "pret_num_val_iterations": 32,
}


def create_sweep(fname, envs, env2experts_list, env2ase_sigma, algorithms, learner_pis, seeds, pggae=False):
    """Write a JSON-lines sweep file with one run configuration per line.

    For every seed and every environment, one entry is emitted for each
    expert set in ``env2experts_list[env]`` combined with each
    ``(learner_pi, algorithm)`` pair.  ``algorithms`` and ``learner_pis`` are
    paired by position.  When ``pggae`` is true, one additional ``"pg-gae"``
    entry is emitted per (seed, environment).  Every entry also carries the
    module-level ``defaults``.

    Entries are serialised with sorted keys and joined with ``'\\n'``; no
    trailing newline is written.

    Args:
        fname: Path of the output file; it is overwritten if it exists.
        envs: Environment names; each is turned into ``dmc:<env>-v1``.
        env2experts_list: Maps an environment name to a list of expert sets.
            Each set is stored under both ``load_expert_step`` and
            ``experts_info``.
        env2ase_sigma: Maps an environment name to the ``ase_sigma`` value
            used when the algorithm is ``"lops-aps-ase"``; every other
            algorithm gets ``0.``.
        algorithms: Algorithm names to sweep over.
        learner_pis: ``use_riro_for_learner_pi`` value for the algorithm at
            the same index in ``algorithms``.
        seeds: Seeds to sweep over.
        pggae: Whether to add the ``"pg-gae"`` entries.

    Raises:
        ValueError: If ``algorithms`` and ``learner_pis`` differ in length.
        KeyError: If an environment is missing from ``env2experts_list``, or
            from ``env2ase_sigma`` while ``"lops-aps-ase"`` is requested.
    """
    # Materialise both sequences once: this allows the length check below and
    # keeps one-shot iterables from being exhausted after the first pass.
    algorithms = list(algorithms)
    learner_pis = list(learner_pis)
    if len(algorithms) != len(learner_pis):
        # A bare zip() would silently drop the unpaired tail and the sweep
        # would be missing runs without any warning.
        raise ValueError(
            'algorithms and learner_pis must have the same length, got '
            f'{len(algorithms)} and {len(learner_pis)}'
        )
    pi_algo_pairs = list(zip(learner_pis, algorithms))

    # mamba and lops-aps
    lines = []
    for seed in seeds:
        for env in envs:
            env_name = f'dmc:{env}-v1'
            print('env', env, 'env_name', env_name)
            for experts in env2experts_list[env]:
                for learner_pi, algorithm in pi_algo_pairs:
                    # Only the ASE variant uses a non-zero sigma.
                    ase_sigma = env2ase_sigma[env] if algorithm == "lops-aps-ase" else 0.
                    lines.append(
                        {
                            "env_name": env_name,
                            "load_expert_step": experts,
                            "experts_info": experts,
                            "algorithm": algorithm,
                            "use_riro_for_learner_pi": learner_pi,
                            "ase_sigma": ase_sigma,
                            "seed": seed,
                            # Kept last on purpose: same override order as before.
                            **defaults
                        }
                    )
            if pggae:
                # pg-gae: one entry per (seed, env), independent of the expert sets.
                # NOTE(review): unlike the entries above, this one has no
                # "experts_info" key -- confirm the consumer tolerates that.
                lines.append(
                    {
                        "env_name": env_name,
                        "load_expert_step": [0],
                        "algorithm": "pg-gae",
                        "use_riro_for_learner_pi": "none",
                        "ase_sigma": 0,
                        "seed": seed,
                        **defaults
                    }
                )

    json_text = [json.dumps(line, sort_keys=True) for line in lines]
    print(f'{len(json_text)} lines to {fname}')
    # json.dumps escapes non-ASCII by default, so pinning the encoding does
    # not change the bytes written; it only removes the locale dependence.
    with open(fname, 'w', encoding='utf-8') as f:
        f.write('\n'.join(json_text))


if __name__ == '__main__':
    import sys

    # Algorithms to compare; `learner_pis` gives, position by position, the
    # `use_riro_for_learner_pi` value each algorithm is run with.
    algorithms = ['mamba', 'lops-aps', 'lops-aps-ase']
    learner_pis = ['none', 'all', 'all']
    seeds = list(range(3))

    # Environments to sweep over
    envs = ['Cheetah-run', 'Walker-walk', 'Pendulum-swingup', 'Cartpole-swingup']

    # Expert sets to try per environment (one group of sweep entries per set)
    env2experts_list = {
        'Cheetah-run': [
            [100],
            [100, 70],
            [100, 70, 40],
            [100, 70, 40, 20],
        ],
        'Walker-walk': [
            [190, 150, 100, 80],
            [150, 100, 80, 50],
            [130, 100, 80, 40],
        ],
        # NOTE(review): [200] is listed twice below, which yields duplicate
        # sweep entries -- confirm whether that is intended.
        'Pendulum-swingup': [
            [200],
            [200, 150],
            [200, 150, 100],
            [200, 150, 100, 50],
            [200],
            [150],
            [100],
            [50],
        ],
        'Cartpole-swingup': [
            [400, 300, 200, 40],
            [400, 140, 80],
            [400, 160, 60],
        ],
    }

    # Per-environment sigma, only used by the 'lops-aps-ase' entries
    env2ase_sigma = {
        'Cheetah-run': 2.5,
        'Walker-walk': 10,
        'Pendulum-swingup': 0.25,
        # TODO: run a sweep to find a good value for Cartpole-swingup
        'Cartpole-swingup': 0.25,
    }

    # The output file takes its name from sys.argv[0] with the extension
    # swapped for .jsonl; being a bare file name, it is created in the
    # current working directory.
    fname = f'{Path(sys.argv[0]).stem}.jsonl'
    create_sweep(
        fname,
        envs,
        env2experts_list,
        env2ase_sigma,
        algorithms,
        learner_pis,
        seeds,
        pggae=True,
    )
